1. Bug fix. 2. Add fast long retention implementation #25
veya2ztn wants to merge 2 commits into syncdoth:main from
Conversation
…2. Add a fixed-length sequence argument when the inputs are (additional_token, past_kv). 3. Add a fast retention implementation for when the sequence length >> D**2.
x = self.dropout_module(x)
return x

self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
if subln:
I would like to keep the use_rms_norm. Also, I would prefer if-else instead of a ternary here. If you want a ternary, could you make it something like:
norm_class = RMSNorm if use_rms_norm else LayerNorm
self.ffn_layernorm = norm_class(ffn_dim, eps=layernorm_eps) if subln else None
The embed_dim should be replaced by ffn_dim, I think:
if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None
to
if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(ffn_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None
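As a rough sketch only, combining the norm_class suggestion above with the ffn_dim fix (assuming use_rms_norm stays as a constructor flag) would give something like:
norm_class = RMSNorm if use_rms_norm else LayerNorm
self.ffn_layernorm = norm_class(ffn_dim, eps=layernorm_eps) if subln else None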
# multi-head
q, k, v = split_heads((q, k, v), B, T, self.num_heads)
k *= self.scaling # for scaled dot product
what's the reasoning for this change?
retnet/modeling_retnet.py
Outdated
- "prev_key_value" # bsz * num_head * v_dim * qk_dim
- "scale" # (1 or bsz) * num_head * 1 * 1
decay_mask, # 1 * num_head * chunk_size * chunk_size
decay_mask, # 1 * num_head * chunk_size * chunk_size
let's keep the spaces consistent.
self.config = config
self.embed_dim = config.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(config.dropout)
self.drop_path = DropPath(np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth]) if config.drop_path_rate > 0 else None
I prefer the previous code. This one-liner is too long and breaks the 100-character limit.
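For example, a possible multi-line form that stays under the limit (a sketch only, using the same config fields as the diff above; drop_path_rates is a local helper name introduced for illustration):
if config.drop_path_rate > 0:
    # rate schedule across decoder layers, indexed by this layer's depth
    drop_path_rates = np.linspace(0, config.drop_path_rate, config.decoder_layers)
    self.drop_path = DropPath(drop_path_rates[depth])
else:
    self.drop_path = None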
ways within their own init.
"""
pass
#pass
hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len))
else:
slen = seq_length
if fixed_seq_len:slen=fixed_seq_len
forward_impl=forward_impl,
recurrent_chunk_size=recurrent_chunk_size,
retention_mask=retention_mask,
get_decay_scale=not self.training)
Why do we want decay scale during training?
Below is an example with one parallel pass and one recurrent pass:

with torch.inference_mode():  # <-- almost equal to `torch.no_grad()`
    model.eval()  # <-- disables dropout, batchnorm, and other layers that behave differently during inference
    out = model(old_inputs,
                forward_impl='parallel',  # <-- this line selects parallel mode
                use_cached=True,  # <-- must have use_cached=True
                **args)
    past_kv = out.past_key_values

model.train()  # if you want to train on later tokens, reactivate training mode here
out = model(new_inputs,
            forward_impl='recurrent',  # <-- this line selects recurrent mode
            use_cached=True,  # <-- must have use_cached=True for further recurrent mode
            past_key_values=past_kv,  # <-- this line is required
            **args)

If we don't call model.eval(), the recurrent mode fails to generate because it receives scale=None. However, model.eval() also changes the behavior of some layers. There is no other important reason here; it is just for my convenience. Basically, the goal is to generate a cache first and reuse it many times:
cache --> task_1
cache --> task_2
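A minimal sketch of that reuse pattern, keeping the same (assumed) argument names as the example above; prompt_inputs, task_1_inputs, and task_2_inputs are hypothetical placeholders:
with torch.inference_mode():
    model.eval()
    shared = model(prompt_inputs, forward_impl='parallel', use_cached=True)
    cache = shared.past_key_values

# the same cached prefix feeds two different continuations
out_task_1 = model(task_1_inputs, forward_impl='recurrent', use_cached=True, past_key_values=cache)
out_task_2 = model(task_2_inputs, forward_impl='recurrent', use_cached=True, past_key_values=cache)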
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
) No newline at end of file
The file should ideally end with a trailing newline (a PEP style convention).
Other than some formatting and refactoring issues, I love the fast-retention implementation! I was hoping to get into that. Thanks for your work!

Will this be merged?

There are some code styling issues and some things I don't understand fully. I think it's great to have its own branch for now.
Cached the fixed retnet_rel_pos (so it does not need to be generated at runtime).
Add fast retention implementation when the sequence length >> D**2. See https://github.com/veya2ztn/fast_retention
5.1 I set use_glu default to false, for consistency with the old code.
5.2 The layer norm setting in the FFN seems wrong; the self.embed_dim should be ffn_dim. Anyway, I roll back to
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None